
Splash attention#67

Merged
pythoncrazy merged 4 commits into master from splash_attention
Jan 19, 2026
Conversation

@pythoncrazy (Owner)

This pull request introduces experimental support for TPU-optimized Splash Attention using the tokamax library. It adds a new configuration class, integrates Splash Attention into the model codebase, and updates documentation and dependencies accordingly. The changes allow users to enable Splash Attention for improved performance on supported hardware (e.g., TPU), and provide a flexible API for configuration.

Splash Attention Integration:

  • Added SplashAttentionConfig dataclass and kernel creation utilities in src/jimm/common/splash_attention.py to encapsulate configuration and provide a factory for Splash Attention functions.
  • Integrated Splash Attention into the transformer and vision transformer modules by adding a splash_attention_config parameter and wiring the custom attention function into nnx.MultiHeadAttention when enabled.

Model and API Updates:

  • Exposed SplashAttentionConfig in the package API (src/jimm/__init__.py) for easy user access.
  • Updated the CLIP model and from_pretrained methods to accept and propagate splash_attention_config.

Documentation and Dependency Updates:

  • Added a new section in README.md documenting Splash Attention usage, configuration, and hardware caveats.
  • Declared tokamax>=0.0.9 as a dependency and updated the minimum Python version to 3.11 in pyproject.toml.

These changes enable experimental, hardware-optimized attention for users on supported devices, while maintaining backward compatibility and clear configuration options.
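As a rough illustration of the configuration surface described above, here is a minimal dataclass sketch. Only the `enabled` flag and `block_kv` field are grounded in this PR's tests; `mask_type`, `block_q`, and all default values are illustrative assumptions.

```python
from dataclasses import dataclass


# Hypothetical sketch of a Splash Attention config; only `enabled` and
# `block_kv` are grounded in this PR's tests -- the other fields and their
# defaults are illustrative assumptions.
@dataclass(frozen=True)
class SplashAttentionConfig:
    enabled: bool = True
    mask_type: str = "full"  # assumed: e.g. "full" or "causal"
    block_q: int = 256       # assumed query block size for the TPU kernel
    block_kv: int = 256      # key/value block size (matches the PR's test)


cfg = SplashAttentionConfig(block_kv=256)
```

A frozen dataclass keeps the configuration hashable and safe to share across modules that receive it.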

- Add SplashAttentionConfig for configuring splash attention parameters
- Integrate splash attention into Transformer, TransformerEncoder, and VisionTransformerBase
- Update CLIP, SigLIP, and VisionTransformer models to accept splash_attention_config
- Add tokamax>=0.0.9 dependency and update Python requirement to >=3.11
- Add tests for splash attention configuration
- Update README with usage documentation
@pythoncrazy (Owner, Author)

/gemini

@gemini-code-assist (Contributor)

Summary of Changes

Hello @pythoncrazy, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the jimm library by introducing experimental support for TPU-optimized Splash Attention. This integration, powered by the tokamax library, aims to provide substantial performance gains for models running on compatible hardware. The changes include a new, flexible configuration system for Splash Attention, seamless integration into core transformer and vision transformer architectures, and necessary updates to dependencies and documentation to guide users in leveraging this new capability.

Highlights

  • TPU-Optimized Splash Attention: Introduced experimental support for Splash Attention, leveraging the tokamax library for enhanced performance on TPUs.
  • Configuration API: Added a new SplashAttentionConfig dataclass and associated utilities to manage Splash Attention settings, including mask type and block sizes.
  • Model Integration: Integrated Splash Attention into Transformer and VisionTransformer modules, allowing models like CLIP and SigLIP to utilize this new attention mechanism.
  • Public API Exposure: Exposed SplashAttentionConfig in the main jimm package API for easy user access and configuration.
  • Dependency and Environment Updates: Added tokamax>=0.0.9 as a new dependency and updated the minimum required Python version to 3.11.
  • Documentation: Included a new section in the README.md detailing how to use and configure Splash Attention, along with important hardware caveats.



@pythoncrazy (Owner, Author)

/gemini review

@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces experimental support for TPU-optimized Splash Attention via the tokamax library. The changes are well-structured, adding a new configuration class and integrating it throughout the model hierarchy. The addition of documentation and tests is also a great practice. My review focuses on improving the consistency of the new create_splash_attention_fn API and simplifying some of the integration logic for better maintainability.

Comment on lines +86 to +101
```python
) -> Callable[..., Float[Array, "batch heads seq head_dim"]]:
    """Create a splash attention function compatible with nnx.MultiHeadAttention.

    Args:
        config (SplashAttentionConfig): Splash attention configuration.
        num_heads (int): Number of attention heads.
        head_dim (int): Dimension of each attention head.

    Returns:
        Callable: An attention function. Returns splash attention if enabled and available,
        otherwise returns the default dot_product_attention.
    """
    if not _TOKAMAX_AVAILABLE or not config.enabled:
        from flax.nnx.nn.attention import dot_product_attention

        return dot_product_attention
```
critical

The function create_splash_attention_fn currently returns dot_product_attention when Splash Attention is disabled or unavailable. This is inconsistent with the new test test_create_fn_returns_none_when_disabled in tests/test_splash_attention.py, which expects None. Returning None would make the function's behavior more explicit, align with the test's expectation, and simplify the logic at call sites.

I suggest changing the implementation to return None when splash attention is not used. You'll also need to update the function's return type hint and docstring. This change will make the test test_create_fn_returns_none_when_disabled pass as written.

Suggested change

Before:

```python
) -> Callable[..., Float[Array, "batch heads seq head_dim"]]:
    """Create a splash attention function compatible with nnx.MultiHeadAttention.

    Args:
        config (SplashAttentionConfig): Splash attention configuration.
        num_heads (int): Number of attention heads.
        head_dim (int): Dimension of each attention head.

    Returns:
        Callable: An attention function. Returns splash attention if enabled and available,
        otherwise returns the default dot_product_attention.
    """
    if not _TOKAMAX_AVAILABLE or not config.enabled:
        from flax.nnx.nn.attention import dot_product_attention

        return dot_product_attention
```

After:

```python
) -> Callable[..., Float[Array, "batch heads seq head_dim"]] | None:
    """Create a splash attention function compatible with nnx.MultiHeadAttention.

    Args:
        config (SplashAttentionConfig): Splash attention configuration.
        num_heads (int): Number of attention heads.
        head_dim (int): Dimension of each attention head.

    Returns:
        Callable | None: An attention function. Returns splash attention if enabled and available,
        otherwise returns None.
    """
    if not _TOKAMAX_AVAILABLE or not config.enabled:
        return None
```
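A library-free sketch of the behavior this suggestion asks for. The names `SplashAttentionConfig`, `create_splash_attention_fn`, and `_TOKAMAX_AVAILABLE` mirror the PR; the single-field config and the stub body are assumptions for illustration.

```python
from dataclasses import dataclass
from typing import Callable, Optional

_TOKAMAX_AVAILABLE = False  # pretend tokamax is not importable


@dataclass
class SplashAttentionConfig:
    enabled: bool = True


def create_splash_attention_fn(config: SplashAttentionConfig) -> Optional[Callable]:
    # Suggested behavior: return None when disabled or unavailable, so the
    # call site can pass attention_fn=None and nnx.MultiHeadAttention falls
    # back to its default dot_product_attention.
    if not _TOKAMAX_AVAILABLE or not config.enabled:
        return None
    raise NotImplementedError("real splash kernel construction elided")


assert create_splash_attention_fn(SplashAttentionConfig(enabled=False)) is None
```

Returning None pushes the fallback decision into one place (the constructor's default) instead of duplicating it in the factory.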

Comment on lines +47 to +74

```python
attention_fn = None
if splash_attention_config is not None:
    attention_fn = create_splash_attention_fn(
        splash_attention_config,
        num_heads=num_heads,
        head_dim=hidden_size // num_heads,
    )

attn_kwargs: dict = {
    "num_heads": num_heads,
    "in_features": hidden_size,
    "broadcast_dropout": False,
    "decode": False,
    "deterministic": False,
    "dtype": dtype,
    "param_dtype": param_dtype,
    "rngs": rngs,
    "kernel_init": nnx.with_partitioning(nnx.initializers.xavier_uniform(), mesh_rules("map_attn_in", "map_attn_out")),
    "bias_init": nnx.with_partitioning(
        nnx.initializers.zeros_init(),
        mesh_rules(
            "map_attn_out",
        ),
    ),
}
if attention_fn is not None:
    attn_kwargs["attention_fn"] = attention_fn
self.attn = nnx.MultiHeadAttention(**attn_kwargs)
```
medium

The initialization of self.attn can be simplified. Instead of creating a dictionary attn_kwargs and conditionally adding attention_fn, you can determine attention_fn first and then pass it directly to the nnx.MultiHeadAttention constructor. Since nnx.MultiHeadAttention handles attention_fn=None by using the default attention, the logic can be made more concise and readable.

Assuming the suggested change in splash_attention.py is made (where create_splash_attention_fn returns None if disabled), you can refactor this block to be cleaner.

```python
attention_fn = None
if splash_attention_config is not None:
    attention_fn = create_splash_attention_fn(
        splash_attention_config,
        num_heads=num_heads,
        head_dim=hidden_size // num_heads,
    )

self.attn = nnx.MultiHeadAttention(
    num_heads=num_heads,
    in_features=hidden_size,
    broadcast_dropout=False,
    decode=False,
    deterministic=False,
    dtype=dtype,
    param_dtype=param_dtype,
    rngs=rngs,
    kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), mesh_rules("map_attn_in", "map_attn_out")),
    bias_init=nnx.with_partitioning(
        nnx.initializers.zeros_init(),
        mesh_rules(
            "map_attn_out",
        ),
    ),
    attention_fn=attention_fn,
)
```

@gemini-code-assist (Contributor)

Acknowledged. I am now performing a full code review of the pull request. I will post my findings shortly.

@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces experimental support for Splash Attention, which is a great addition for performance on TPUs. The changes are well-structured, adding a new configuration and integrating it throughout the model hierarchy. I've found a critical issue in the implementation of the create_splash_attention_fn which would cause a new test to fail, along with a related inconsistency in how the attention function is handled in different models. My review includes suggestions to fix these issues to ensure correctness and improve maintainability.

Comment on lines +98 to +101

```python
if not _TOKAMAX_AVAILABLE or not config.enabled:
    from flax.nnx.nn.attention import dot_product_attention

    return dot_product_attention
```
critical

The function should return None when Splash Attention is disabled or unavailable. This allows nnx.MultiHeadAttention to fall back to its default dot_product_attention. The current implementation returns dot_product_attention directly, which is inconsistent with the new test test_create_fn_returns_none_when_disabled and can lead to less clear control flow in the model definitions.

Suggested change

Before:

```python
if not _TOKAMAX_AVAILABLE or not config.enabled:
    from flax.nnx.nn.attention import dot_product_attention

    return dot_product_attention
```

After:

```python
if not _TOKAMAX_AVAILABLE or not config.enabled:
    return None
```

Comment on lines +86 to +97

```python
) -> Callable[..., Float[Array, "batch heads seq head_dim"]]:
    """Create a splash attention function compatible with nnx.MultiHeadAttention.

    Args:
        config (SplashAttentionConfig): Splash attention configuration.
        num_heads (int): Number of attention heads.
        head_dim (int): Dimension of each attention head.

    Returns:
        Callable: An attention function. Returns splash attention if enabled and available,
        otherwise returns the default dot_product_attention.
    """
```
medium

The return type hint and docstring for create_splash_attention_fn should be updated to reflect that it can return None when splash attention is disabled. This makes the function signature more accurate and consistent with the proposed implementation change.

Suggested change

Before:

```python
) -> Callable[..., Float[Array, "batch heads seq head_dim"]]:
    """Create a splash attention function compatible with nnx.MultiHeadAttention.

    Args:
        config (SplashAttentionConfig): Splash attention configuration.
        num_heads (int): Number of attention heads.
        head_dim (int): Dimension of each attention head.

    Returns:
        Callable: An attention function. Returns splash attention if enabled and available,
        otherwise returns the default dot_product_attention.
    """
```

After:

```python
) -> Callable[..., Float[Array, "batch heads seq head_dim"]] | None:
    """Create a splash attention function compatible with nnx.MultiHeadAttention.

    Args:
        config (SplashAttentionConfig): Splash attention configuration.
        num_heads (int): Number of attention heads.
        head_dim (int): Dimension of each attention head.

    Returns:
        Callable | None: An attention function if splash attention is enabled and available,
        otherwise `None`.
    """
```

Comment on lines +73 to +81

```python
attention_fn = (
    create_splash_attention_fn(
        splash_attention_config,
        num_heads=num_heads,
        head_dim=hidden_size // num_heads,
    )
    if splash_attention_config is not None
    else nnx.dot_product_attention
)
```
medium

The logic for determining attention_fn can be simplified and made more consistent with the pattern used in src/jimm/common/vit.py. By initializing attention_fn to None and only creating the splash function if splash_attention_config is provided, the code becomes more explicit and readable. nnx.MultiHeadAttention will correctly use its default when attention_fn is None.

This change depends on create_splash_attention_fn returning None when disabled, as suggested in another comment.

```python
attention_fn = None
if splash_attention_config is not None:
    attention_fn = create_splash_attention_fn(
        splash_attention_config,
        num_heads=num_heads,
        head_dim=hidden_size // num_heads,
    )
```

@pythoncrazy (Owner, Author)

/gemini review

@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces experimental support for TPU-optimized Splash Attention using the tokamax library. It adds a new configuration class, integrates Splash Attention into the model codebase, and updates documentation and dependencies accordingly. The changes are well-structured, with new functionality encapsulated in src/jimm/common/splash_attention.py. The integration into existing models is done by plumbing the new configuration through the model hierarchy, which is a clean approach. New tests are added to ensure the changes don't break existing functionality when Splash Attention is disabled. My review focuses on improving code clarity and maintainability in the new splash_attention module and related components. I've suggested simplifying some logic and improving naming for better readability. Overall, this is a great addition to the library.

Comment on lines +98 to +101

```python
if not _TOKAMAX_AVAILABLE or not config.enabled:
    from flax.nnx.nn.attention import dot_product_attention

    return dot_product_attention
```
medium

The import of dot_product_attention is performed inside this function. It's generally better practice to place all imports at the top of the file for improved readability and to avoid repeated import operations. Since flax is a core dependency of the project, you can add from flax.nnx.nn.attention import dot_product_attention at the top of the file and simplify this block.

Suggested change

Before:

```python
if not _TOKAMAX_AVAILABLE or not config.enabled:
    from flax.nnx.nn.attention import dot_product_attention

    return dot_product_attention
```

After:

```python
if not _TOKAMAX_AVAILABLE or not config.enabled:
    return dot_product_attention
```

"""
seq_len = query.shape[2]
kernel = _create_splash_kernel(seq_len, num_heads, head_dim, config)
return jax.vmap(lambda q, k, v: kernel(q, k, v))(query, key, value)
medium

The lambda function used here is redundant. jax.vmap can directly operate on the kernel function, which makes the code more concise and easier to read.

Suggested change

Before:

```python
return jax.vmap(lambda q, k, v: kernel(q, k, v))(query, key, value)
```

After:

```python
return jax.vmap(kernel)(query, key, value)
```
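The equivalence is easy to check directly. A quick sketch, assuming JAX is installed; the `kernel` here is a plain softmax-attention stand-in for illustration, not the tokamax splash kernel:

```python
import jax
import jax.numpy as jnp


def kernel(q, k, v):
    # Stand-in for the per-example splash kernel: ordinary softmax attention
    # over (heads, seq, head_dim) inputs.
    scores = jnp.einsum("hqd,hkd->hqk", q, k) / jnp.sqrt(q.shape[-1])
    return jnp.einsum("hqk,hkd->hqd", jax.nn.softmax(scores, axis=-1), v)


key = jax.random.PRNGKey(0)
q, k, v = (
    jax.random.normal(k_i, (2, 4, 8, 16))  # (batch, heads, seq, head_dim)
    for k_i in jax.random.split(key, 3)
)
out_lambda = jax.vmap(lambda q, k, v: kernel(q, k, v))(q, k, v)
out_direct = jax.vmap(kernel)(q, k, v)  # identical: the lambda adds nothing
assert jnp.allclose(out_lambda, out_direct)
```

`jax.vmap` maps over the leading axis of each argument by default, so wrapping the kernel in a pass-through lambda changes nothing.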

Comment on lines +47 to +74

```python
attention_fn = None
if splash_attention_config is not None:
    attention_fn = create_splash_attention_fn(
        splash_attention_config,
        num_heads=num_heads,
        head_dim=hidden_size // num_heads,
    )

attn_kwargs: dict = {
    "num_heads": num_heads,
    "in_features": hidden_size,
    "broadcast_dropout": False,
    "decode": False,
    "deterministic": False,
    "dtype": dtype,
    "param_dtype": param_dtype,
    "rngs": rngs,
    "kernel_init": nnx.with_partitioning(nnx.initializers.xavier_uniform(), mesh_rules("map_attn_in", "map_attn_out")),
    "bias_init": nnx.with_partitioning(
        nnx.initializers.zeros_init(),
        mesh_rules(
            "map_attn_out",
        ),
    ),
}
if attention_fn is not None:
    attn_kwargs["attention_fn"] = attention_fn
self.attn = nnx.MultiHeadAttention(**attn_kwargs)
```
medium

The logic for creating and assigning the attention_fn can be simplified. You can directly create and add the attention_fn to the attn_kwargs dictionary within a single conditional block, which removes the need for the intermediate attention_fn variable and a second if statement. This makes the code more direct and readable.

```python
attn_kwargs: dict = {
    "num_heads": num_heads,
    "in_features": hidden_size,
    "broadcast_dropout": False,
    "decode": False,
    "deterministic": False,
    "dtype": dtype,
    "param_dtype": param_dtype,
    "rngs": rngs,
    "kernel_init": nnx.with_partitioning(nnx.initializers.xavier_uniform(), mesh_rules("map_attn_in", "map_attn_out")),
    "bias_init": nnx.with_partitioning(
        nnx.initializers.zeros_init(),
        mesh_rules(
            "map_attn_out",
        ),
    ),
}
if splash_attention_config is not None:
    attn_kwargs["attention_fn"] = create_splash_attention_fn(
        splash_attention_config,
        num_heads=num_heads,
        head_dim=hidden_size // num_heads,
    )
self.attn = nnx.MultiHeadAttention(**attn_kwargs)
```

```python
    assert config.block_kv == 256


def test_create_fn_returns_none_when_disabled() -> None:
```
medium

The name of this test function, test_create_fn_returns_none_when_disabled, is a bit misleading as the function doesn't return None but rather the default dot_product_attention function. Renaming it would improve clarity.

Suggested change

Before:

```python
def test_create_fn_returns_none_when_disabled() -> None:
```

After:

```python
def test_create_fn_returns_default_when_disabled() -> None:
```

@pythoncrazy pythoncrazy merged commit e496f89 into master Jan 19, 2026
1 check failed
pythoncrazy added a commit that referenced this pull request Jan 20, 2026